import collections
import logging
import os

import torch
from torch.utils.data import DataLoader

from metric import Metric
from runners.utils import AverageMeter, TimeMeter
from models.Moco import MoCo_Model
import numpy as np
class MainRunner:
    """Own the dataset, model, optimizer and metric built from ``config`` and
    run the train / evaluate / checkpoint loop.

    ``config`` is indexed as a nested mapping. Keys read in this class:
    ``train`` (``batch_size``, ``max_num_epochs``, ``model_saved_path``,
    ``lr_scheduler``; the whole section is also handed to the optimizer and
    scheduler), ``dataset`` / ``generalization_dataset`` (``name`` plus
    dataset options), ``model`` (``name`` plus model options) and ``loss``
    (name of a class in the ``losses`` package).
    """

    # Checkpoint ``train`` resumes from when ``RESUME`` is true; can be
    # overridden per instance through ``self.resume_path``.
    DEFAULT_RESUME_PATH = "checkpoints/r2a/main/new-model-93.pt"
    # Video whose predictions ``eval`` paints, and where the frames go.
    EVAL_PAINT_VIDEO = 'xrTQfifI0No'
    EVAL_PAINT_DIR = 'bad_a2d'
    # Video whose predictions ``eval_generalization`` paints, and where.
    GENERALIZATION_PAINT_VIDEO = '-2akYw9VucA'
    GENERALIZATION_PAINT_DIR = 'openset_specific_a2d'
    # Loss-dict entries that are not tracked in the training log.
    _HIDDEN_LOSS_KEYS = ('loss40', 'loss160')
    # Value passed to the model as ``epoch`` during evaluation.
    _EVAL_EPOCH = 100
    _GENERALIZATION_BANNER = "***********************************eval generalization results********************************"

    def __init__(self, config):
        self.config = config
        self.num_updates = 0
        self._build_dataset()
        self._build_model()
        self._build_optimizer()
        self.metric = Metric()
        # When true, ``train`` restores model/optimizer/epoch from resume_path.
        self.RESUME = True
        self.resume_path = self.DEFAULT_RESUME_PATH

    def train(self):
        """Run the training loop, optionally resuming from ``self.resume_path``.

        Each epoch: evaluate on the generalization split, train one epoch,
        save ``new-model-{epoch}.pt`` under ``config['train']['model_saved_path']``,
        then evaluate again after epochs 10, 20, 30, 35 and every epoch >= 40.
        """
        # Fixed starting learning rate; on resume the scheduler is then
        # advanced to the restored update count.
        self.lr_scheduler.set_lr(6e-4)
        start_epoch = -1

        if self.RESUME:
            checkpoint = torch.load(self.resume_path)
            self.model.load_state_dict(checkpoint['parameters'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            self.num_updates = checkpoint['num_updates']
            self.lr_scheduler.step_update(self.num_updates)

        for epoch in range(start_epoch + 1, self.config['train']['max_num_epochs'] + 1):
            logging.info('Start Epoch {}'.format(epoch))
            # NOTE(review): after a scheduled post-epoch evaluation this
            # re-evaluates the same weights; kept to preserve the logs.
            self.eval_generalization()

            self._train_one_epoch(epoch)
            os.makedirs(self.config['train']['model_saved_path'], mode=0o755, exist_ok=True)
            save_path = os.path.join(self.config['train']['model_saved_path'],
                                     'new-model-{}.pt'.format(epoch))
            self._save_model(save_path, epoch)
            print('------------IoU-----------')
            torch.cuda.empty_cache()
            self.metric = Metric(tau=0.5)
            if epoch in (10, 20, 30, 35) or epoch >= 40:
                self.eval_generalization()
            torch.cuda.empty_cache()
            logging.info('=' * 60)

    def eval(self):
        """Evaluate on ``self.test_loader``, painting masks of ``EVAL_PAINT_VIDEO``.

        Prints the metric summary plus per-sample colour diagnostics for the
        painted video. Returns ``loss_meter['margin_loss'].avg`` as computed by
        ``_evaluate``.
        """
        return self._evaluate(self.test_loader, self.EVAL_PAINT_VIDEO,
                              self.EVAL_PAINT_DIR, debug=True)

    def eval_generalization(self):
        """Evaluate on ``self.generalization_loader``, painting masks of
        ``GENERALIZATION_PAINT_VIDEO``.

        The printed summary is framed by a banner line. Returns
        ``loss_meter['margin_loss'].avg`` as computed by ``_evaluate``.
        """
        return self._evaluate(self.generalization_loader, self.GENERALIZATION_PAINT_VIDEO,
                              self.GENERALIZATION_PAINT_DIR,
                              banner=self._GENERALIZATION_BANNER)

    def _evaluate(self, loader, paint_video_id, paint_dir, debug=False, banner=None):
        """Run the model over ``loader``, accumulate ``self.metric`` and print it.

        For every sample whose ``video_info`` video id equals ``paint_video_id``,
        the predicted mask is composited into ``<paint_dir>/<video_id>/<frame>.jpg``.
        Each distinct query of that video gets its own colour. ``banner``
        (if given) is printed before and after the summary. ``debug`` prints
        the colour bookkeeping for each painted sample.

        Returns ``loss_meter['margin_loss'].avg``. NOTE(review): ``loss_meter``
        is never updated during evaluation, so this is the average of an empty
        ``AverageMeter``.
        """
        self.metric.reset()
        self.model.eval()
        loss_meter = collections.defaultdict(lambda: AverageMeter())
        colors = self._tableau_palette()
        # video_id -> distinct queries in first-seen order; the position picks
        # the paint colour.
        video_queries = {}

        # No gradients are needed; without no_grad every forward pass would
        # build an autograd graph and keep its saved activations on the GPU.
        with torch.no_grad():
            for batch in loader:
                net_input = move_to_cuda(batch['net_input'])
                output, _ = self.model(**net_input, epoch=self._EVAL_EPOCH)
                iou, _, pred_mask = self.metric(**output, gt=batch['ori_fine_gt_mask'],
                                                target_frame=batch['target_frame'])
                video_info = batch['video_info']
                for a, sample_iou in enumerate(iou):
                    video_id, instance_id, frame_idx, query, _ = video_info[a]
                    if video_id != paint_video_id:
                        continue
                    queries = video_queries.setdefault(video_id, [])
                    if query in queries:
                        ptr = queries.index(query)
                    else:
                        ptr = len(queries)
                        queries.append(query)
                    print(video_id, frame_idx, instance_id, query, sample_iou)
                    # The palette has a fixed size; wrap around instead of
                    # raising IndexError for videos with many queries.
                    color = colors[ptr % len(colors)][1]
                    if debug:
                        print("ptr:", ptr)
                        print("color:", colors)
                        print(color)
                    self._paint_mask(paint_dir, video_id, frame_idx, pred_mask[a], color)

        if banner is not None:
            print(banner)
        self._report_metrics(loss_meter)
        torch.cuda.empty_cache()
        if banner is not None:
            print(banner)
        return loss_meter['margin_loss'].avg

    @staticmethod
    def _tableau_palette():
        """Return ``[(name, bgr_uint8_array), ...]`` for matplotlib's TABLEAU_COLORS."""
        import matplotlib.colors as mcolors

        def hex_to_bgr(hex_color):
            # '#rrggbb' -> [b, g, r], the channel order cv2 uses.
            r = int(hex_color[1:3], 16)
            g = int(hex_color[3:5], 16)
            b = int(hex_color[5:7], 16)
            return np.asarray([b, g, r], dtype=np.uint8)

        return [(name, hex_to_bgr(value)) for name, value in mcolors.TABLEAU_COLORS.items()]

    @staticmethod
    def _paint_mask(paint_dir, video_id, frame_idx, pred_mask, color):
        """Composite ``pred_mask`` in ``color`` onto ``<paint_dir>/<video_id>/<frame_idx:05d>.jpg``.

        An existing image at that path is reused as the canvas, so several
        queries of the same frame accumulate. Otherwise a black canvas with the
        mask's height and width is started. ``pred_mask`` is indexed as 2-D here.
        """
        import cv2

        video_dir = os.path.join(paint_dir, video_id)
        frame_path = os.path.join(video_dir, '{:05d}.jpg'.format(int(frame_idx)))
        os.makedirs(video_dir, exist_ok=True)
        mask = pred_mask[:, :, np.newaxis]
        if os.path.exists(frame_path):
            canvas = cv2.imread(frame_path)
        else:
            canvas = np.zeros([mask.shape[0], mask.shape[1], 3], dtype=np.uint8)
        canvas = canvas * (1 - mask) + np.reshape(color, [1, 1, 3]) * mask
        cv2.imwrite(frame_path, canvas)

    def _report_metrics(self, loss_meter):
        """Print the ``self.metric`` summary followed by every ``loss_meter`` average."""
        print(
            '｜ num {} ｜ mean IoU {:.4f} ｜ overall IoU {:.4f} ｜ mAP[0.5:0.95] {:.4f} | mAP[0.1:0.55] {:.4f} |'.format(
                self.metric.num_samples,
                self.metric.mean_iou(), self.metric.overall_iou(),
                self.metric.average_precision(),
                self.metric.average_precision1()))
        for k in range(1, 10):
            print('｜ P@{} {:.4f} '.format(k, self.metric.precision(k)), end=' ')
        print('｜')
        for k, v in loss_meter.items():
            print('{}: {}, '.format(k, v.avg), end='')

    def _train_one_epoch(self, epoch):
        """Train for one pass over ``self.train_loader``.

        The optimized loss is ``loss_fn(output)`` plus, when ``epoch >= 0``,
        the mean of the model's ``clip_loss``. Gradients are accumulated over
        ``acc_steps`` batches before each optimizer/scheduler update. Running
        averages are logged every ``display_n_batches`` batches and once more
        at the end. Always returns ``None``.
        """
        from fairseq.utils import move_to_cuda

        def print_log():
            curr_lr = self.lr_scheduler.optimizer.get_lr()
            msg = 'Epoch {}, Batch {}, lr = {:.5f}, '.format(epoch, bid, curr_lr)
            for k, v in loss_meter.items():
                msg += '{} = {:.4f}, '.format(k, v.avg)
                v.reset()
            msg += '{:.3f} seconds/batch'.format(1.0 / time_meter.avg)
            logging.info(msg)

        self.model.train()
        display_n_batches, bid = 100, 0
        time_meter = TimeMeter()
        loss_meter = collections.defaultdict(lambda: AverageMeter())
        self.optimizer.zero_grad()
        acc_steps = 1
        for bid, batch in enumerate(self.train_loader, 1):
            net_input = move_to_cuda(batch['net_input'])
            output, clip_loss = self.model(**net_input, epoch=epoch)
            loss, loss_dict, _ = self.loss_fn(output, full=None)
            # clip_loss is presumably gathered per DataParallel replica; reduce
            # it to one value.
            prompt_loss = clip_loss.mean(0)
            if epoch >= 0:
                loss = loss + prompt_loss
            # Scale the whole objective so accumulated gradients match one
            # large batch.
            (loss / acc_steps).backward()

            if bid % acc_steps == 0:
                self.optimizer.step()
                self.num_updates += 1
                self.lr_scheduler.step_update(self.num_updates)
                self.optimizer.zero_grad()

            time_meter.update()
            for k, v in loss_dict.items():
                if k not in self._HIDDEN_LOSS_KEYS:
                    loss_meter[k].update(v)
            # Detach so the meter does not keep autograd history alive.
            loss_meter['prompt'].update(prompt_loss.detach())

            if bid % display_n_batches == 0:
                print_log()

        if bid % display_n_batches != 0:
            print_log()
        return None

    def _save_model(self, path, epoch):
        """Save parameters, optimizer state, ``num_updates`` and ``epoch`` to ``path``."""
        state_dict = {
            'parameters': self.model.state_dict(),
            'num_updates': self.num_updates,
            'epoch': epoch,
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(state_dict, path)
        logging.info('saved model to {}'.format(path))

    def _load_model(self, path):
        """Load parameters and ``num_updates`` from ``path`` (the optimizer state is not restored)."""
        state_dict = torch.load(path)
        self.num_updates = state_dict['num_updates']
        self.lr_scheduler.step_update(self.num_updates)
        logging.info('load model from {}, {}'.format(
            path, self.model.load_state_dict(state_dict['parameters'])))

    def _build_dataset(self):
        """Build the training and generalization datasets and their DataLoaders."""
        import datasets
        batch_size = self.config['train']['batch_size']
        self.dataset = getattr(datasets, self.config['dataset']['name'])(self.config['dataset'])
        self.generalization_dataset = getattr(datasets, self.config['generalization_dataset']['name'])(
            self.config['generalization_dataset'])
        self.train_loader = DataLoader(self.dataset.train_set, batch_size=batch_size, shuffle=True,
                                       collate_fn=self.dataset.collate_fn,
                                       num_workers=8, pin_memory=True)
        # NOTE(review): reads generalization_dataset.test_set but collates with
        # self.dataset.collate_fn - confirm this pairing is intended.
        self.test_loader = DataLoader(self.generalization_dataset.test_set, batch_size=8, shuffle=False,
                                      collate_fn=self.dataset.collate_fn, num_workers=8,
                                      pin_memory=True)
        self.generalization_loader = DataLoader(self.generalization_dataset.test_set, batch_size=4, shuffle=False,
                                                collate_fn=self.generalization_dataset.collate_fn, num_workers=8,
                                                pin_memory=True)

    def _build_model(self):
        """Instantiate the model and loss from config; wrap the model in DataParallel.

        GPU ids are ``0..n-1``. ``n`` is the number of entries in
        ``CUDA_VISIBLE_DEVICES``, or ``torch.cuda.device_count()`` when that
        variable is unset or empty.

        Raises:
            RuntimeError: if no CUDA device is available.
            ValueError: if the configured model or loss name does not exist.
        """
        import models
        import losses
        from losses import Sum_image_loss

        visible = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
        if visible:
            device_ids = list(range(len(visible.split(','))))
        else:
            device_ids = list(range(torch.cuda.device_count()))
        if not device_ids:
            raise RuntimeError('No CUDA device available (CUDA_VISIBLE_DEVICES={!r})'.format(visible))
        logging.info('GPU: {}'.format(device_ids))

        model_name = self.config['model']['name']
        model_cls = getattr(models, model_name, None)
        if model_cls is None:
            raise ValueError('Unknown model {!r}: not found in package models'.format(model_name))
        self.model = model_cls(self.config['model'])
        self.model.load_pretrained_weights()
        self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        self.model = self.model.cuda(device_ids[0])
        self.device_ids = device_ids

        loss_name = self.config['loss']
        loss_cls = getattr(losses, loss_name, None)
        if loss_cls is None:
            raise ValueError('Unknown loss {!r}: not found in package losses'.format(loss_name))
        self.loss_fn = loss_cls()
        self.sum_image_loss = Sum_image_loss()
        print(self.loss_fn)

    def _build_optimizer(self):
        """Create the Adam optimizer over the wrapped model's parameters and its LR scheduler."""
        from optimizers import AdamOptimizer
        import optimizers.lr_schedulers as lr_schedulers
        parameters = list(self.model.module.parameters())
        args = self.config['train']
        self.optimizer = AdamOptimizer(args, parameters)
        self.lr_scheduler = \
            getattr(lr_schedulers, args['lr_scheduler'])(args, self.optimizer)


def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor inside ``sample`` and return the rebuilt structure.

    Dicts, lists and tuples are traversed recursively and rebuilt with the
    same keys and order. Namedtuples keep their type. Any other non-tensor
    leaf is returned unchanged, and ``sample`` may itself be a tensor,
    including a 0-d or empty one.

    An empty non-tensor container (``{}``, ``[]``, ``""``, ...) returns ``{}``
    for backward compatibility.
    """
    # Tensors are excluded: len() raises on 0-d tensors, and an empty tensor
    # is a value to transform, not an empty container.
    if not torch.is_tensor(sample) and hasattr(sample, '__len__') and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(item) for item in x]
        if isinstance(x, tuple):
            items = [_apply(item) for item in x]
            # namedtuples must be rebuilt through _make to keep their type.
            return x._make(items) if hasattr(x, '_make') else tuple(items)
        return x

    return _apply(sample)


def move_to_cuda(sample):
    """Return ``sample`` with every tensor leaf replaced by its ``.cuda()`` copy.

    The traversal of nested containers (and the handling of non-tensor
    leaves and empty input) is delegated to ``apply_to_sample``.
    """
    return apply_to_sample(lambda t: t.cuda(), sample)



